package async;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import jdk.internal.vm.Continuation;
import jdk.internal.vm.ContinuationScope;
import util.UnsafeUtil;

// Run with: java --add-opens java.base/jdk.internal.vm=ALL-UNNAMED -Xlog:gc=info,gc+heap=info:gc_loom.log:time async.Loom4Test
// Sample output from one run: 886 ms, r = 4995000000 (r is always N1 * N2 * (0 + 1 + ... + (N3 - 1)))
/**
 * Micro-benchmark of raw continuation switching built on the JDK-internal
 * {@link Continuation} class. It starts the continuation once with {@code run()},
 * then yields and re-enters it through reflectively obtained handles to
 * {@code yield0} and {@code enterSpecial}.
 *
 * <p>{@link #main} runs {@code f1 -> f2 -> f3} inside a single continuation.
 * {@code f3} stores its loop index in {@link #s_i} and then calls {@link #pause()},
 * so the body pauses {@code N = N1 * N2 * N3} times in total. After each pause
 * {@code main} adds {@code s_i} to a running sum and calls {@link #resume}. At the
 * end it prints the elapsed time and the sum.
 *
 * <p>Requires runtime access to {@code jdk.internal.vm} (for example via
 * {@code --add-opens java.base/jdk.internal.vm=ALL-UNNAMED}) and the project
 * helper {@code util.UnsafeUtil}.
 */
public class Loom4Test {
	/** Scope used both to create the benchmark continuation and to yield from it. */
	private static final ContinuationScope cs = new ContinuationScope("VirtualThreads");
	/** Unsafe offset of {@code java.lang.Thread}'s {@code cont} field, which {@link #pause()} reads as a {@link Continuation}. */
	private static final long CONT_OFFSET;
	/**
	 * Handle to the instance method {@code Continuation.yield0(ContinuationScope, Continuation)}.
	 * It is obtained reflectively, and the receiver is passed as the first argument.
	 */
	private static final MethodHandle YIELD0;
	/**
	 * Handle to {@code Continuation.enterSpecial(Continuation, boolean, boolean)}, obtained reflectively.
	 * The method is static: {@link #resume} invokes the handle with exactly the three declared arguments.
	 */
	private static final MethodHandle ENTER_SPECIAL;
	// Loop bounds of f1, f2 and f3 respectively; N is the total number of pauses.
	private static final int N1 = 100;
	private static final int N2 = 100;
	private static final int N3 = 1000;
	private static final int N = N1 * N2 * N3;
	/**
	 * Last loop index published by {@code f3} before pausing; {@code main} reads it after each pause.
	 * It is a plain static field with no synchronization.
	 * NOTE(review): this relies on the continuation running on main's own thread.
	 */
	private static int s_i;

	// Resolve the field offset and method handles once, at class initialization.
	// Any ReflectiveOperationException is rethrown unchecked, so class init fails.
	// NOTE(review): UnsafeUtil.getMethod presumably returns a Method that is already
	// accessible; otherwise lookup.unreflect would reject it. Confirm in UnsafeUtil.
	static {
		try {
			CONT_OFFSET = UnsafeUtil.objectFieldOffset(Thread.class, "cont");
			var lookup = MethodHandles.lookup();
			YIELD0 = lookup.unreflect(UnsafeUtil.getMethod(Continuation.class, "yield0",
					ContinuationScope.class, Continuation.class));
			ENTER_SPECIAL = lookup.unreflect(UnsafeUtil.getMethod(Continuation.class, "enterSpecial",
					Continuation.class, boolean.class, boolean.class));
		} catch (ReflectiveOperationException e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Suspends the continuation this code is running in; returns once it is resumed.
	 *
	 * <p>It reads the continuation from the current thread's {@code cont} field and calls
	 * {@code yield0(cs, null)} on it directly, instead of going through
	 * {@code Continuation.yield(cs)}. After {@code yield0} returns or throws, the
	 * {@code cont} field is written back to that same continuation.
	 *
	 * <p>NOTE(review): the restore is presumably needed because {@link #resume} re-enters
	 * through {@code enterSpecial}. That skips whatever bookkeeping {@code Continuation.run()}
	 * performs on the thread's {@code cont} field. Confirm against the JDK source in use.
	 *
	 * <p>It must be called from inside a continuation. Otherwise the field is presumably
	 * {@code null} and the invocation fails, wrapped as below.
	 *
	 * @throws RuntimeException wrapping any {@link Throwable} (including Errors) thrown by {@code yield0}
	 */
	@SuppressWarnings("removal")
	static void pause() {
		var c = (Continuation)UnsafeUtil.unsafe.getObject(Thread.currentThread(), CONT_OFFSET);
		try {
			// invokeExact requires the call-site type to match the handle's type exactly:
			// (Continuation, ContinuationScope, Continuation)boolean. The boolean result must
			// therefore be cast and bound to a local, even though it is ignored.
			@SuppressWarnings("unused")
			var __ = (boolean)YIELD0.invokeExact(c, cs, (Continuation)null);
			VarHandle.releaseFence(); // needed to prevent certain transformations by the compiler
		} catch (Throwable e) {
			throw new RuntimeException(e);
		} finally {
			// Runs after the continuation has been resumed (or yield0 threw):
			// put c back into the thread's cont field (see the NOTE above).
			UnsafeUtil.unsafe.putObject(Thread.currentThread(), CONT_OFFSET, c);
		}
	}

	/**
	 * Re-enters a suspended continuation by calling {@code Continuation.enterSpecial(c, true, false)}
	 * directly, rather than {@code c.run()}. Control comes back here when the body next pauses.
	 *
	 * <p>NOTE(review): the two boolean flags presumably mean "is continue" and "is virtual thread".
	 * Confirm against the JDK version in use.
	 *
	 * @param c the continuation to resume
	 * @throws IllegalStateException if {@code c} has already finished
	 * @throws RuntimeException wrapping any {@link Throwable} thrown by {@code enterSpecial}
	 */
	static void resume(Continuation c) {
		if (c.isDone())
			throw new IllegalStateException("Continuation terminated");
		try {
			ENTER_SPECIAL.invokeExact(c, true, false);
		} catch (Throwable e) {
			throw new RuntimeException(e);
		}
	}

	/** Innermost level: publishes each index 0..N3-1 in {@link #s_i}, then pauses. */
	private static void f3() {
		for (int i = 0; i < N3; i++) {
			s_i = i;
			pause();
		}
	}

	/** Middle level: calls {@link #f3()} N2 times, so each pause happens below an extra stack frame. */
	private static void f2() {
		for (int i = 0; i < N2; i++)
			f3();
	}

	/** Outer level and continuation entry point: calls {@link #f2()} N1 times. */
	private static void f1() {
		for (int i = 0; i < N1; i++)
			f2();
	}

	/**
	 * Runs the benchmark and prints {@code "<elapsed> ms, r = <sum>"}.
	 *
	 * <p>{@code c.run()} executes the body up to its first pause. The loop then adds
	 * {@code s_i} once per pause ({@code N} times) and resumes only {@code N - 1} times.
	 * As a result, the continuation is left suspended at its last pause and never completes.
	 *
	 * <p>The sum is {@code N1 * N2 * (N3 * (N3 - 1) / 2)}, which is 4995000000 for the current
	 * constants. That exceeds the int range, hence {@code long r}. The reported time also
	 * includes creating the continuation.
	 *
	 * @param args ignored
	 */
	public static void main(String[] args) {
		long t = System.nanoTime();
		var c = new Continuation(cs, Loom4Test::f1);
		c.run();
		long r = 0;
		for (int i = 0; ; ) {
			r += s_i;
			// Stop after the N-th sample, before resuming past the final pause.
			if (++i >= N)
				break;
			resume(c);
		}
		System.out.println((System.nanoTime() - t) / 1_000_000 + " ms, r = " + r);
	}
}
